# -*- coding: utf-8 -*-
import os
from typing import List

from zvdata import IntervalLevel


# kdata schema rule
# 1)name:{entity_type}{level}Kdata
# 2)one db file for one schema

def gen_quote_domain(providers: List[str], entity_type: str, levels):
    tables = []
    for level in levels:
        level = IntervalLevel(level)

        cap_entity_type = entity_type.capitalize()
        cap_level = level.value.capitalize()

        if level != IntervalLevel.LEVEL_TICK:
            kdata_common = f'{cap_entity_type}KdataCommon'
        else:
            kdata_common = f'{cap_entity_type}TickCommon'

        class_name = f'{cap_entity_type}{cap_level}Kdata'
        table_name = f'{entity_type}_{level.value}_kdata'
        tables.append(table_name)

        domain_template = f'''# -*- coding: utf-8 -*-
# this file is generated by gen_quote_domain function, dont't change it
from sqlalchemy.ext.declarative import declarative_base

from zvdata.contract import register_schema
from zvt.domain.quotes import {kdata_common}

KdataBase = declarative_base()


class {class_name}(KdataBase, {kdata_common}):
    __tablename__ = '{table_name}'


register_schema(providers={providers}, db_name='{table_name}', schema_base=KdataBase)

__all__ = ['{class_name}']
'''
        # generate the domain
        with open(os.path.join(entity_type, f'{table_name}.py'), 'w') as outfile:
            outfile.write(domain_template)

    # generate the package
    imports = [f'from zvt.domain.quotes.{entity_type}.{table} import *' for table in tables]
    imports_str = '\n'.join(imports)

    package_template = '''# -*- coding: utf-8 -*-
# this file is generated by gen_quote_domain function, dont't change it
''' + imports_str

    with open(os.path.join(entity_type, '__init__.py'), 'w') as outfile:
        outfile.write(package_template)


if __name__ == '__main__':
    # 股票行情
    gen_quote_domain(providers=['joinquant'], entity_type='stock',
                     levels=[level for level in IntervalLevel if level != IntervalLevel.LEVEL_TICK])

    # 数字货币行情
    gen_quote_domain(providers=['ccxt'], entity_type='coin',
                     levels=[level for level in IntervalLevel])

    # 板块行情
    gen_quote_domain(providers=['eastmoney'], entity_type='block',
                     levels=[IntervalLevel.LEVEL_1DAY, IntervalLevel.LEVEL_1WEEK, IntervalLevel.LEVEL_1MON])

    # etf行情
    gen_quote_domain(providers=['sina'], entity_type='etf',
                     levels=[IntervalLevel.LEVEL_1DAY])
    # 指数行情
    gen_quote_domain(providers=['sina'], entity_type='index',
                     levels=[IntervalLevel.LEVEL_1DAY])
